import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat
import math
from functools import partial
from timm.layers import DropPath
try:
    from .shared_modules import RelativePositionBias, ContinuousPositionBias1D, MLP
except:
    from shared_modules import RelativePositionBias, ContinuousPositionBias1D, MLP
   
class LayerNorm(nn.Module):
    """Channel-first wrapper around ``nn.LayerNorm``.

    ``nn.LayerNorm`` normalizes over the trailing dimension(s), so the channel
    axis (dim 1) is moved to the end, normalized, and moved back. Any number of
    dimensions may follow the channel axis.

    All positional/keyword arguments are forwarded to ``nn.LayerNorm``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.layernorm = nn.LayerNorm(*args, **kwargs)

    def forward(self, x):
        channels_last = x.movedim(1, -1)
        normed = self.layernorm(channels_last)
        return normed.movedim(-1, 1)

# Param builder func

    
def build_space_block(params):
    """Return a factory for the spatial mixing block described by ``params``.

    Only ``params.space_type == 'axial_attention'`` is supported. The returned
    ``functools.partial`` pre-binds ``params.embed_dim``, ``params.num_heads``
    and ``bias_type=params.bias_type`` onto ``AxialAttentionBlock``.

    Raises:
        NotImplementedError: for any other ``space_type``.
    """
    if params.space_type != 'axial_attention':
        raise NotImplementedError
    return partial(AxialAttentionBlock, params.embed_dim, params.num_heads,
                   bias_type=params.bias_type)

### Space utils

class RMSInstanceNorm2d(nn.Module):
    """Per-sample, per-channel scale normalization over the last two axes.

    Each slice is divided by its own (unbiased, torch default) standard
    deviation over the last two dimensions; the mean is *not* subtracted.
    With ``affine=True`` the result is multiplied by a learned per-channel
    ``weight`` broadcast over dim 1 of a 4D input.

    Args:
        dim: number of channels (size of dim 1 of the input).
        affine: whether to apply the learned per-channel scale.
        eps: added to the standard deviation to avoid division by zero.
    """

    def __init__(self, dim, affine=True, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.affine = affine
        if affine:
            self.weight = nn.Parameter(torch.ones(dim))
            # Unused by forward(); it has to stay because it is present in
            # the pretrained checkpoints' state dicts.
            self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        std, _ = torch.std_mean(x, dim=(-2, -1), keepdim=True)
        normalized = x / (std + self.eps)
        if not self.affine:
            return normalized
        return normalized * self.weight.view(1, -1, 1, 1)

    
class SubsampledLinear(nn.Module):
    """
    Cross between a linear layer and EmbeddingBag - takes in input
    and list of indices denoting which state variables from the state
    vocab are present and only performs the linear layer on rows/cols relevant
    to those state variables

    Assumes (... C) input

    Args:
        dim_in: input features of the full (un-subsampled) layer.
        dim_out: output features of the full layer.
        subsample_in: if True, labels select input columns of ``weight``
            (state variables live on the input side and ``bias`` is used
            whole); otherwise labels select output rows of ``weight`` and the
            matching entries of ``bias``.
    """
    def __init__(self, dim_in, dim_out, subsample_in=True):
        super().__init__()
        self.subsample_in = subsample_in
        self.dim_in = dim_in
        self.dim_out = dim_out
        temp_linear = nn.Linear(dim_in, dim_out)
        self.weight = nn.Parameter(temp_linear.weight)
        self.bias = nn.Parameter(temp_linear.bias)

    def inflation(self, src=(6, 7), dst=10):
        """Initialize state variable ``dst`` as the mean of state variables ``src``.

        The defaults set variable 10 to the mean of variables 6 and 7.

        The state-variable axis is the input axis (weight columns) when
        ``subsample_in`` is True. In that case ``bias`` is indexed by *output*
        feature, has no state-variable axis, and is left untouched. Otherwise
        the state-variable axis is the output axis (weight rows), and ``bias``
        is updated the same way.

        The parameters are modified in place, so the same Parameter objects
        (and any references to them, e.g. an optimizer) remain valid.

        Raises:
            ValueError: if ``src`` is empty.
        """
        src = list(src)
        if not src:
            raise ValueError("inflation() needs at least one source index")
        with torch.no_grad():
            if self.subsample_in:
                self.weight[:, dst] = sum(self.weight[:, i] for i in src) / len(src)
            else:
                self.weight[dst] = sum(self.weight[i] for i in src) / len(src)
                self.bias[dst] = sum(self.bias[i] for i in src) / len(src)

    def forward(self, x, labels):
        """Apply the layer restricted to the state variables in ``labels``.

        Args:
            x: input with features in the last dimension.
            labels: indexable whose first element is the list/tensor of
                state-variable indices; only ``labels[0]`` is used, so every
                sample in the batch must share the same state variables.

        When ``subsample_in`` is True, the output is multiplied by
        ``sqrt(dim_in / len(labels[0]))``, which compensates for the reduced
        fan-in.
        """
        # Note - really only works if all batches are the same input type
        labels = labels[0] # Figure out how to handle this for normal batches later
        label_size = len(labels)
        if self.subsample_in:
            scale = (self.dim_in / label_size)**.5 # Equivalent to swapping init to correct for given subsample of input
            x = scale * F.linear(x, self.weight[:, labels], self.bias)
        else:
            x = F.linear(x, self.weight[labels], self.bias[labels])
        return x

class hMLP_stem(nn.Module):
    """ Image to Patch Embedding

    3D hierarchical patch embedding built from 2D parameters. It stacks three
    non-overlapping convolutions with kernel == stride == ``p1``, ``p2``,
    ``p3`` (e.g. ``patch_size[0] == 16`` gives p1=4, p2=p3=2).

    The weights are stored as ``nn.Conv2d`` kernels so 2D checkpoints load
    unchanged. In ``forward`` each kernel is repeated along a new trailing
    depth axis and divided by that axis' length. A depth-constant input
    therefore reproduces the 2D response, apart from the training-time noise
    described below.

    Args:
        patch_size: only ``patch_size[0]`` is used.
        in_chans: number of input channels.
        embed_dim: output channels; the two inner stages use ``embed_dim // 4``.
        noise_scale: in training mode, one random scalar of size
            ``noise_scale * std(2D kernel)`` is added to each inflated kernel.
            Set to 0 to disable.
    """
    # Class-level default so instances that predate this attribute
    # (e.g. unpickled whole modules) still work.
    noise_scale = 0.01

    def __init__(self, patch_size=(16,16), in_chans=3, embed_dim=768, noise_scale=0.01):
        super().__init__()
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.noise_scale = noise_scale

        patch_size = patch_size[0]
        p3 = int(patch_size**(1/4))
        p2 = p3
        p1 = patch_size // p2 // p3
        self.p3 = p3
        self.p2 = p2
        self.p1 = p1
        self.in_proj = torch.nn.Sequential(
            *[nn.Conv2d(in_chans, embed_dim//4, kernel_size=p1, stride=p1, bias=False),
            LayerNorm(embed_dim//4),
            nn.GELU(),
            nn.Conv2d(embed_dim//4, embed_dim//4, kernel_size=p2, stride=p2, bias=False),
            LayerNorm(embed_dim//4),
            nn.GELU(),
            nn.Conv2d(embed_dim//4, embed_dim, kernel_size=p3, stride=p3, bias=False),
            LayerNorm(embed_dim),
            ]
            )

    def _inflate(self, weight, depth):
        """Lift a 2D kernel (out, in, h, w) to a 3D kernel (out, in, h, w, depth).

        The kernel is copied ``depth`` times along the new axis and divided by
        ``depth``. In training mode a single random scalar, scaled by
        ``noise_scale`` times the standard deviation of ``weight``, is added
        to the whole kernel. Eval-mode results are deterministic.
        """
        kernel = weight.unsqueeze(-1).expand(*weight.shape, depth) / depth
        if self.training and self.noise_scale:
            # Scale by the weight's standard deviation (not its mean).
            std = torch.std(weight).detach()
            kernel = kernel + self.noise_scale * std * torch.randn_like(std)
        return kernel

    def forward(self, x):
        """Embed ``x`` into patch tokens.

        Assumes ``x`` is (B, C, H, W, D) — TODO confirm against callers. Each
        spatial dimension is reduced by p1, p2 and p3 in turn (floor division
        at each stage). The output has ``embed_dim`` channels.
        """
        conv1, conv2, conv3 = self.in_proj[0], self.in_proj[3], self.in_proj[6]
        x = F.conv3d(x, self._inflate(conv1.weight, self.p1), conv1.bias, stride=self.p1)
        x = self.in_proj[2](self.in_proj[1](x))  # LayerNorm -> GELU
        x = F.conv3d(x, self._inflate(conv2.weight, self.p2), conv2.bias, stride=self.p2)
        x = self.in_proj[5](self.in_proj[4](x))  # LayerNorm -> GELU
        x = F.conv3d(x, self._inflate(conv3.weight, self.p3), conv3.bias, stride=self.p3)
        return self.in_proj[7](x)  # LayerNorm
    
    
class hMLP_output(nn.Module):
    """ Patch to Image De-bedding

    3D hierarchical patch de-embedding built from 2D parameters. It applies
    three transposed convolutions with kernel == stride == ``p3``, ``p2``,
    ``p1``.

    The two inner stages are stored as ``nn.ConvTranspose2d`` weights. The
    head is stored as raw parameters ``out_kernel`` (shape of a
    ``ConvTranspose2d(embed_dim // 4, out_chans, p1)`` weight) and
    ``out_bias`` (out_chans,). This lets forward produce only the output
    channels selected by ``state_labels``.

    Kernels are lifted to 3D in ``forward`` by repeating them along a trailing
    depth axis and dividing by that axis' length.

    Args:
        patch_size: only ``patch_size[0]`` is used.
        out_chans: number of head output channels (``state_labels`` index these).
        embed_dim: input channels; the inner stages use ``embed_dim // 4``.
        noise_scale: in training mode, one random scalar of size
            ``noise_scale * std(2D kernel)`` is added to each inflated kernel.
            Set to 0 to disable.
    """
    # Class-level default so instances that predate this attribute
    # (e.g. unpickled whole modules) still work.
    noise_scale = 0.01

    def __init__(self, patch_size=(16,16), out_chans=3, embed_dim=768, noise_scale=0.01):
        super().__init__()
        self.patch_size = patch_size
        self.out_chans = out_chans
        self.embed_dim = embed_dim
        self.noise_scale = noise_scale

        patch_size = patch_size[0]
        p3 = int(patch_size**(1/4))
        p2 = p3
        p1 = patch_size // p2 // p3
        self.p1 = p1
        self.p2 = p2
        self.p3 = p3
        self.out_proj = torch.nn.Sequential(
            *[nn.ConvTranspose2d(embed_dim, embed_dim//4, kernel_size=p3, stride=p3, bias=False),
            LayerNorm(embed_dim//4),
            nn.GELU(),
            nn.ConvTranspose2d(embed_dim//4, embed_dim//4, kernel_size=p2, stride=p2, bias=False),
            LayerNorm(embed_dim//4),
            nn.GELU(),
            ])
        out_head = nn.ConvTranspose2d(embed_dim//4, out_chans, kernel_size=p1, stride=p1)
        self.out_kernel = nn.Parameter(out_head.weight)
        self.out_bias = nn.Parameter(out_head.bias)

    def inflation(self, src=(6, 7), dst=10):
        """Initialize output channel ``dst`` as the mean of output channels ``src``.

        Updates ``out_kernel[:, dst]`` and ``out_bias[dst]`` in place, so the
        Parameter objects stay the same. The defaults set channel 10 to the
        mean of channels 6 and 7.

        Raises:
            ValueError: if ``src`` is empty.
        """
        src = list(src)
        if not src:
            raise ValueError("inflation() needs at least one source index")
        with torch.no_grad():
            self.out_kernel[:, dst] = sum(self.out_kernel[:, i] for i in src) / len(src)
            self.out_bias[dst] = sum(self.out_bias[i] for i in src) / len(src)

    def _inflate(self, weight, depth):
        """Lift a 2D transposed-conv kernel (in, out, h, w) to (in, out, h, w, depth).

        The kernel is copied ``depth`` times along the new axis and divided by
        ``depth``. In training mode a single random scalar, scaled by
        ``noise_scale`` times the standard deviation of ``weight``, is added
        to the whole kernel. Eval-mode results are deterministic.
        """
        kernel = weight.unsqueeze(-1).expand(*weight.shape, depth) / depth
        if self.training and self.noise_scale:
            # Scale by the weight's standard deviation (not its mean).
            std = torch.std(weight).detach()
            kernel = kernel + self.noise_scale * std * torch.randn_like(std)
        return kernel

    def forward(self, x, state_labels):
        """Decode patch tokens back to the output channels in ``state_labels``."""
        conv1, conv2 = self.out_proj[0], self.out_proj[3]
        x = F.conv_transpose3d(x, self._inflate(conv1.weight, self.p3), None, stride=self.p3)
        x = self.out_proj[2](self.out_proj[1](x))  # LayerNorm -> GELU
        x = F.conv_transpose3d(x, self._inflate(conv2.weight, self.p2), None, stride=self.p2)
        x = self.out_proj[5](self.out_proj[4](x))  # LayerNorm -> GELU
        # Use the Parameter itself (not .data) so the head kernel receives gradients.
        out_kernel = self._inflate(self.out_kernel, self.p1)
        return F.conv_transpose3d(
            x, out_kernel[:, state_labels], self.out_bias[state_labels], stride=self.p1)
    
class AxialAttentionBlock(nn.Module):
    """Transformer block with axial self-attention over three spatial axes.

    Attention branch:
      * normalizes the input;
      * projects it to q/k/v with a 1x1x1 convolution (stored as a 1x1
        ``nn.Conv2d`` so 2D checkpoints load);
      * runs multi-head attention separately along D, W and H and averages
        the three results;
      * normalizes and projects back.

    MLP branch: applies ``MLP`` channels-last, followed by ``mlp_norm``.

    Each branch is added residually after optional per-channel layer scaling
    and DropPath.

    Args:
        hidden_dim: channel width C (expected to be divisible by ``num_heads``).
        num_heads: number of attention heads.
        drop_path: DropPath probability; 0 disables it.
        layer_scale_init_value: initial value of the per-channel branch scales
            ``gamma_att`` / ``gamma_mlp``. If <= 0 they are None and no
            scaling is applied.
        bias_type: 'none' (no positional bias), 'continuous'
            (``ContinuousPositionBias1D``), or anything else
            (``RelativePositionBias``).
    """
    def __init__(self, hidden_dim=768, num_heads=12,  drop_path=0, layer_scale_init_value=1e-6, bias_type='rel'):
        super().__init__()
        self.num_heads = num_heads
        self.norm1 = LayerNorm(hidden_dim)
        self.norm2 = LayerNorm(hidden_dim)
        # Per-channel residual scales ("layer scale"); None disables scaling.
        self.gamma_att = nn.Parameter(layer_scale_init_value * torch.ones((hidden_dim)),
                            requires_grad=True) if layer_scale_init_value > 0 else None
        self.gamma_mlp = nn.Parameter(layer_scale_init_value * torch.ones((hidden_dim)),
                            requires_grad=True) if layer_scale_init_value > 0 else None

        # 1x1 Conv2d weights, applied as 1x1x1 Conv3d in forward().
        self.input_head = nn.Conv2d(hidden_dim, 3*hidden_dim, 1)
        self.output_head = nn.Conv2d(hidden_dim, hidden_dim, 1)
        # Per-head q/k normalization.
        self.qnorm = nn.LayerNorm(hidden_dim//num_heads)
        self.knorm = nn.LayerNorm(hidden_dim//num_heads)
        if bias_type == 'none':
            # forward() calls rel_pos_bias(n, n, bc) with three arguments.
            self.rel_pos_bias = lambda *args, **kwargs: None
        elif bias_type == 'continuous':
            self.rel_pos_bias = ContinuousPositionBias1D(n_heads=num_heads)
        else:
            self.rel_pos_bias = RelativePositionBias(n_heads=num_heads)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.mlp = MLP(hidden_dim)
        self.mlp_norm = LayerNorm(hidden_dim)

    @staticmethod
    def _attend(q, k, v, bias):
        """Scaled dot-product attention, using ``bias`` as an additive mask if given."""
        # Functional SDPA does not return the attention matrix.
        if bias is not None:
            return F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
        # The explicit .contiguous() calls were needed on this path in the
        # original implementation (reason not recorded).
        return F.scaled_dot_product_attention(q.contiguous(), k.contiguous(), v.contiguous())

    @staticmethod
    def _layer_scale(x, gamma):
        """Scale channel dim 1 of a 5D tensor by ``gamma``; identity when ``gamma`` is None."""
        if gamma is None:
            return x
        return x * gamma[None, :, None, None, None]

    def forward(self, x, bcs):
        """Apply the attention and MLP branches.

        Args:
            x: tensor of shape (B, C, H, W, D).
            bcs: boundary-condition tensor. ``bcs[0, 0]`` is passed to the
                D-axis position bias and ``bcs[0, 1]`` to the W- and H-axis
                biases.

        Returns:
            Tensor with the same shape as ``x``.
        """
        B, C, H, W, D = x.shape
        residual = x.clone()
        x = self.norm1(x)
        x = F.conv3d(x, self.input_head.weight.unsqueeze(-1), self.input_head.bias, stride=1)

        # Split channels into heads; the last axis then holds [q | k | v].
        x = rearrange(x, 'b (he c) h w d ->  b he h w d c', he=self.num_heads)
        q, k, v = x.tensor_split(3, dim=-1)
        q, k = self.qnorm(q), self.knorm(k)

        # Attend along each spatial axis separately, then average the results.
        # D axis
        qx, kx, vx = (rearrange(t, 'b he h w d c ->  (b h w) he d c') for t in (q, k, v))
        xx = self._attend(qx, kx, vx, self.rel_pos_bias(D, D, bcs[0, 0]))
        xx = rearrange(xx, '(b h w) he d c -> b (he c) h w d', h=H, w=W)

        # W axis
        qy, ky, vy = (rearrange(t, 'b he h w d c ->  (b h d) he w c') for t in (q, k, v))
        xy = self._attend(qy, ky, vy, self.rel_pos_bias(W, W, bcs[0, 1]))
        xy = rearrange(xy, '(b h d) he w c -> b (he c) h w d', h=H, d=D)

        # H axis
        # NOTE(review): reuses bcs[0, 1] like the W axis; confirm whether a
        # separate entry (e.g. bcs[0, 2]) was intended for H.
        qz, kz, vz = (rearrange(t, 'b he h w d c ->  (b w d) he h c') for t in (q, k, v))
        xz = self._attend(qz, kz, vz, self.rel_pos_bias(H, H, bcs[0, 1]))
        xz = rearrange(xz, '(b w d) he h c -> b (he c) h w d', w=W, d=D)

        x = (xx + xy + xz) / 3
        x = self.norm2(x)
        x = F.conv3d(x, self.output_head.weight.unsqueeze(-1), self.output_head.bias, stride=1)
        x = self.drop_path(self._layer_scale(x, self.gamma_att)) + residual

        # MLP branch (MLP is applied channels-last).
        residual = x.clone()
        x = rearrange(x, 'b c h w d -> b h w d c')
        x = self.mlp(x)
        x = rearrange(x, 'b h w d c -> b c h w d')
        x = self.mlp_norm(x)
        return residual + self.drop_path(self._layer_scale(x, self.gamma_mlp))